package services

import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/pcap"
	"github.com/wailsapp/wails/v2/pkg/runtime"
	"log"
	"net/http"
	"procPacket/model"
	"procPacket/utils"
	"strconv"
	"strings"
	"sync"
	"time"
)

var (
	// index is the sequence number stamped on the next packet emitted to the
	// frontend; the capture goroutines advance it after each emitted packet.
	index int

	// NeedPorts maps a local port, formatted as a decimal string, to the pid
	// (int32) of a selected process that has a connection on that port. The
	// background goroutine started in init refreshes it periodically.
	NeedPorts sync.Map

	// SelectPid holds the process IDs whose connections are being watched.
	SelectPid []int32

	// AllDevices is every capture interface returned by utils.GetAllDevices
	// during init.
	AllDevices []pcap.Interface

	// NeedDevices lists the interface names GetPacket starts a capture on.
	NeedDevices []string

	// DeviceMap resolves an interface name to its description.
	DeviceMap = map[string]string{}

	// CancelFunc cancels the context of the most recent GetPacket call.
	CancelFunc context.CancelFunc
)

// init discovers the capture devices, selects all of them by default, records
// each device's description in DeviceMap, and starts a background goroutine
// that refreshes NeedPorts every 3 seconds.
func init() {
	AllDevices = utils.GetAllDevices()
	for _, dev := range AllDevices {
		NeedDevices = append(NeedDevices, dev.Name)
		DeviceMap[dev.Name] = dev.Description
	}
	go func() {
		for {
			refreshNeedPorts()
			time.Sleep(3 * time.Second)
		}
	}()
}

// refreshNeedPorts performs one refresh of NeedPorts from the local-address
// ports of the connections reported for each pid in SelectPid. When several
// pids report the same port, the later pid in SelectPid wins.
func refreshNeedPorts() {
	// NOTE(review): SelectPid is assigned by SetSelectPid without
	// synchronization, so this read can race with that write.
	current := make(map[string]int32)
	for _, pid := range SelectPid {
		// Only the local address is considered for now.
		for _, conn := range utils.GetNeedConnStats(pid) {
			current[strconv.Itoa(int(conn.Laddr.Port))] = pid
		}
	}

	// Update in place instead of `NeedPorts = sync.Map{}`: reassigning copies a
	// lock value (go vet copylocks), races with the capture goroutines calling
	// NeedPorts.Load, and leaves the map empty until it is refilled, during
	// which packets on tracked ports are skipped.
	for port, pid := range current {
		NeedPorts.Store(port, pid)
	}
	NeedPorts.Range(func(key, _ interface{}) bool {
		port, _ := key.(string)
		if _, keep := current[port]; !keep {
			NeedPorts.Delete(key)
		}
		return true
	})
}

func GetPacket(ctx context.Context) context.CancelFunc {
	cancelCtx, cancelFunc := context.WithCancel(ctx)
	CancelFunc = cancelFunc
	for _, device := range NeedDevices {
		log.Println("start a go " + device)
		go func(device string, ctx context.Context) {
			handle, err := pcap.OpenLive(device, 1024, false, 30*time.Second)
			if err != nil {
				log.Fatal(err)
			}
			defer handle.Close()
			err = handle.SetBPFFilter("tcp")
			if err != nil {
				log.Fatal(err)
			}
			ps := gopacket.NewPacketSource(handle, handle.LinkType())
			var mPacket = model.Packet{}
			for {
				select {
				case packet := <-ps.Packets():
					ipLayer := packet.Layer(layers.LayerTypeIPv4)
					tcpLayer := packet.Layer(layers.LayerTypeTCP)
					if tcpLayer != nil {
						tcp, _ := tcpLayer.(*layers.TCP)
						mPacket.SrcPort = strings.Split(tcp.SrcPort.String(), "(")[0]
						mPacket.DstPort = strings.Split(tcp.DstPort.String(), "(")[0]
						mPacket.Window = tcp.Window
					} else {
						continue
					}
					pid, ok := NeedPorts.Load(mPacket.SrcPort)
					pid2, ok2 := NeedPorts.Load(mPacket.DstPort)
					if !(ok || ok2) {
						continue
					}
					if ok {
						mPacket.Pid = pid.(int32)
					} else if ok2 {
						mPacket.Pid = pid2.(int32)
					} else {
						continue
					}
					if ipLayer != nil {
						ip, _ := ipLayer.(*layers.IPv4)
						mPacket.SrcIp = ip.SrcIP.String()
						mPacket.DstIp = ip.DstIP.String()
						mPacket.Protocol = ip.Protocol.String()
					} else {
						continue
					}
					mPacket.Data = base64.StdEncoding.EncodeToString(packet.Data())
					mPacket.DataString = string(packet.Data())
					mPacket.DeviceName = device
					mPacket.DeviceDescription = DeviceMap[device]
					if strings.Contains(string(packet.Data()), "HTTP") {
						mPacket.Protocol = "HTTP"
						request, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(packet.ApplicationLayer().Payload())))
						if err != nil {
							//log.Println(err.Error())
						} else {
							//log.Printf("%+v\n", request)
							//log.Printf("%+v\n", request.Body)
							mPacket.Method = request.Method
						}
						//log.Printf("%+v\n", mPacket)
					}
					mPacket.Index = index
					index++
					//log.Println("package_receive ", index)
					runtime.EventsEmit(ctx, "package_receive", mPacket)
				//}
				case <-ctx.Done():
					log.Printf("%s done...\n", device)
					return
				}
			}
		}(device, cancelCtx)
	}
	return cancelFunc
}

// SetSelectPid sets the process IDs whose connections populate NeedPorts on
// the next refresh. The slice is copied so later in-place changes the caller
// makes to pids cannot alias the value the background refresher reads.
func SetSelectPid(pids []int32) {
	SelectPid = append([]int32(nil), pids...)
}

// SetNeedDevices sets the interface names that subsequent GetPacket calls
// capture on. The slice is copied so the caller can keep modifying its own
// slice without changing the selection.
func SetNeedDevices(devs []string) {
	NeedDevices = append([]string(nil), devs...)
}
